{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.04618305, -0.0019841 , -0.02022721,  0.04200636], dtype=float32)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "#定义环境\n",
    "class MyWrapper(gym.Wrapper):\n",
    "\n",
    "    def __init__(self):\n",
    "        env = gym.make('CartPole-v1')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "\n",
    "    def reset(self):\n",
    "        state, _ = self.env.reset()\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        state, reward, done, _, info = self.env.step(action)\n",
    "        return state, reward, done, info\n",
    "\n",
    "\n",
    "env = MyWrapper()\n",
    "\n",
    "env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<stable_baselines3.ppo.ppo.PPO at 0x7f298b62afd0>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from stable_baselines3 import PPO\n",
    "from stable_baselines3.common.torch_layers import BaseFeaturesExtractor\n",
    "\n",
    "\n",
    "#自定义特征抽取层\n",
    "class CustomCNN(BaseFeaturesExtractor):\n",
    "\n",
    "    def __init__(self, observation_space, hidden_dim):\n",
    "        super().__init__(observation_space, hidden_dim)\n",
    "\n",
    "        self.sequential = torch.nn.Sequential(\n",
    "\n",
    "            #[b, 4, 1, 1] -> [b, h, 1, 1]\n",
    "            torch.nn.Conv2d(in_channels=observation_space.shape[0],\n",
    "                            out_channels=hidden_dim,\n",
    "                            kernel_size=1,\n",
    "                            stride=1,\n",
    "                            padding=0),\n",
    "            torch.nn.ReLU(),\n",
    "\n",
    "            #[b, h, 1, 1] -> [b, h, 1, 1]\n",
    "            torch.nn.Conv2d(hidden_dim,\n",
    "                            hidden_dim,\n",
    "                            kernel_size=1,\n",
    "                            stride=1,\n",
    "                            padding=0),\n",
    "            torch.nn.ReLU(),\n",
    "\n",
    "            #[b, h, 1, 1] -> [b, h]\n",
    "            torch.nn.Flatten(),\n",
    "\n",
    "            #[b, h] -> [b, h]\n",
    "            torch.nn.Linear(hidden_dim, hidden_dim),\n",
    "            torch.nn.ReLU(),\n",
    "        )\n",
    "\n",
    "    def forward(self, state):\n",
    "        b = state.shape[0]\n",
    "        state = state.reshape(b, -1, 1, 1)\n",
    "        return self.sequential(state)\n",
    "\n",
    "\n",
    "model = PPO('CnnPolicy',\n",
    "            env,\n",
    "            policy_kwargs={\n",
    "                'features_extractor_class': CustomCNN,\n",
    "                'features_extractor_kwargs': {\n",
    "                    'hidden_dim': 8\n",
    "                },\n",
    "            },\n",
    "            verbose=0)\n",
    "\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/envs/pt39/lib/python3.9/site-packages/stable_baselines3/common/evaluation.py:67: UserWarning: Evaluation environment is not wrapped with a ``Monitor`` wrapper. This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. Consider wrapping environment first with ``Monitor`` wrapper.\n",
      "  warnings.warn(\n",
      "[W NNPACK.cpp:53] Could not initialize NNPACK! Reason: Unsupported hardware.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(21.1, 9.289241088485108)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from stable_baselines3.common.evaluation import evaluate_policy\n",
    "\n",
    "#测试\n",
    "evaluate_policy(model, env, n_eval_episodes=10, deterministic=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "03671d25c8ca4e04a614ad63f371002a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#训练\n",
    "model.learn(total_timesteps=2_0000, progress_bar=True)\n",
    "\n",
    "model.save('models/自定义特征抽取层')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "zpz8kHlt_a_m"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(522.7, 428.98812337872477)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = PPO.load('models/自定义特征抽取层')\n",
    "\n",
    "evaluate_policy(model, env, n_eval_episodes=10, deterministic=False)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "Copie de Unit 1: Train your first Deep Reinforcement Learning Agent 🚀.ipynb",
   "private_outputs": true,
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  },
  "vscode": {
   "interpreter": {
    "hash": "ed7f8024e43d3b8f5ca3c5e1a8151ab4d136b3ecee1e3fd59e0766ccc55e1b10"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
